用Tensorflow和FastAPI构建图像分类API

您所在的位置:网站首页 python38 tensorflow 用Tensorflow和FastAPI构建图像分类API

用Tensorflow和FastAPI构建图像分类API

#用Tensorflow和FastAPI构建图像分类API| 来源: 网络整理| 查看: 265

首先,我们导入FastAPI类并创建一个对象应用程序。这个类有一些有用的参数,比如我们可以传递swaggerui的标题和描述。

from fastapi import FastAPIapp = FastAPI(title='Hello world')

我们定义一个函数并用@app.get. 这意味着我们的API/index支持GET方法。这里定义的函数是异步的,FastAPI通过为普通的def函数创建线程池来自动处理异步和不使用异步方法,并且它为异步函数使用异步事件循环。

@app.get('/index')async def hello_world(): return "hello world"图像识别API

我们将创建一个API来对图像进行分类,我们将其命名为predict/image。我们将使用Tensorflow来创建图像分类模型。

我们创建了一个函数load_model,它将返回一个带有预训练权重的MobileNet CNN模型,即它已经被训练为对1000个不同类别的图像进行分类。

import tensorflow as tfdef load_model(): model = tf.keras.applications.MobileNetV2(weights="imagenet") print("Model loaded") return modelmodel = load_model()

我们定义了一个predict函数,它将接受图像并返回预测。我们将图像大小调整为224x224,并将像素值规格化为[-1,1]。

from tensorflow.keras.applications.imagenet_utils import decode_predictions

decode_predictions用于解码预测对象的类名。这里我们将返回前2个可能的类。

def predict(image: Image.Image): image = np.asarray(image.resize((224, 224)))[..., :3] image = np.expand_dims(image, 0) image = image / 127.5 - 1.0 result = decode_predictions(model.predict(image), 2)[0] response = [] for i, res in enumerate(result): resp = {} resp["class"] = res[1] resp["confidence"] = f"{res[2]*100:0.2f} %" response.append(resp) return response

现在我们将创建一个支持文件上传的API/predict/image。我们将过滤文件扩展名以仅支持jpg、jpeg和png格式的图像。

我们将使用Pillow加载上传的图像。

def read_imagefile(file) -> Image.Image: image = Image.open(BytesIO(file)) return [email protected]("/predict/image")async def predict_api(file: UploadFile = File(...)): extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png") if not extension: return "Image must be jpg or png format!" image = read_imagefile(await file.read()) prediction = predict(image) return prediction最终代码import uvicornfrom fastapi import FastAPI, File, UploadFilefrom application.components import predict, read_imagefileapp = FastAPI()@app.post("/predict/image")async def predict_api(file: UploadFile = File(...)): extension = file.filename.split(".")[-1] in ("jpg", "jpeg", "png") if not extension: return "Image must be jpg or png format!" image = read_imagefile(await file.read()) prediction = predict(image) return [email protected]("/api/covid-symptom-check")def check_risk(symptom: Symptom): return symptom_check.get_risk_level(symptom)if __name__ == "__main__": uvicorn.run(app, debug=True)



【本文地址】


今日新闻


推荐新闻


CopyRight 2018-2019 办公设备维修网 版权所有 豫ICP备15022753号-3